{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pommerman Demo.\n",
    "\n",
    "This notebook demonstrates how to train Pommerman agents. Please let us know at support@pommerman.com if you run into any issues."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "\n",
    "from pommerman.agents import SimpleAgent, RandomAgent, PlayerAgent, BaseAgent\n",
    "from pommerman.configs import ffa_v0_fast_env\n",
    "from pommerman.envs.v0 import Pomme\n",
    "from pommerman.characters import Bomber\n",
    "from pommerman import utility"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Random agents\n",
    "\n",
    "The following code instantiates the environment with four random agents who take actions until the game is finished. (This will be a quick game.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# Build the `ffa_v0_fast_env` configuration and construct the Pomme\n",
    "# environment from the keyword arguments that configuration provides.\n",
    "config = ffa_v0_fast_env()\n",
    "env_kwargs = config[\"env_kwargs\"]\n",
    "env = Pomme(**env_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Populate all four slots with random agents, keyed by agent id.\n",
    "make_character = config[\"agent\"]\n",
    "game_type = config[\"game_type\"]\n",
    "\n",
    "agents = {}\n",
    "for agent_id in range(4):\n",
    "    character = make_character(agent_id, game_type)\n",
    "    agents[agent_id] = RandomAgent(character)\n",
    "\n",
    "# Hand the agents to the environment and start from a fresh game state.\n",
    "env.set_agents(list(agents.values()))\n",
    "env.set_init_game_state(None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'result': <Result.Win: 0>, 'winners': [3]}\n"
     ]
    }
   ],
   "source": [
    "# Seed and reset the environment\n",
    "env.seed(0)\n",
    "obs = env.reset()\n",
    "\n",
    "# Run the random agents until we're done. The try/finally guarantees the\n",
    "# render window is closed even if the loop raises or the kernel is\n",
    "# interrupted part-way through a game.\n",
    "done = False\n",
    "try:\n",
    "    while not done:\n",
    "        env.render()\n",
    "        actions = env.act(obs)\n",
    "        obs, reward, done, info = env.step(actions)\n",
    "finally:\n",
    "    env.render(close=True)\n",
    "    env.close()\n",
    "\n",
    "# Print the result\n",
    "print(info)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Human Agents\n",
    "\n",
    "The following code runs the environment with 3 random agents and one agent with human input (use the arrow keys on your keyboard). This can also be called on the command line with:\n",
    "\n",
    "`python run_battle.py --agents=player::arrows,random::null,random::null,random::null --config=PommeFFACompetition-v0`\n",
    "\n",
    "You can also run this with SimpleAgents by executing:\n",
    "\n",
    "`python run_battle.py --agents=player::arrows,test::agents.SimpleAgent,test::agents.SimpleAgent,test::agents.SimpleAgent --config=PommeFFACompetition-v0`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# Instantiate the environment\n",
    "config = ffa_v0_fast_env()\n",
    "env = Pomme(**config[\"env_kwargs\"])\n",
    "\n",
    "# Add 3 random agents (ids 0, 1 and 2)\n",
    "agents = {}\n",
    "for agent_id in range(3):\n",
    "    agents[agent_id] = RandomAgent(config[\"agent\"](agent_id, config[\"game_type\"]))\n",
    "\n",
    "# Add the human agent in the last slot. It needs its own id: reusing the\n",
    "# loop variable here would give it id 2, the same id as the last random\n",
    "# agent, because `agent_id` is still 2 once the loop above has finished.\n",
    "human_agent_id = 3\n",
    "agents[human_agent_id] = PlayerAgent(config[\"agent\"](human_agent_id, config[\"game_type\"]), \"arrows\")\n",
    "\n",
    "env.set_agents(list(agents.values()))\n",
    "env.set_init_game_state(None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'result': <Result.Tie: 2>}\n"
     ]
    }
   ],
   "source": [
    "# Seed and reset the environment\n",
    "env.seed(0)\n",
    "obs = env.reset()\n",
    "\n",
    "# Run the agents until we're done. The try/finally guarantees the render\n",
    "# window is closed even if the loop raises or the kernel is interrupted\n",
    "# while the human player is still in the game.\n",
    "done = False\n",
    "try:\n",
    "    while not done:\n",
    "        env.render()\n",
    "        actions = env.act(obs)\n",
    "        obs, reward, done, info = env.step(actions)\n",
    "finally:\n",
    "    env.render(close=True)\n",
    "    env.close()\n",
    "\n",
    "# Print the result\n",
    "print(info)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training an Agent\n",
    "\n",
    "The following code uses Tensorforce to train a PPO agent. This is in the train_with_tensorforce.py module as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure you have tensorforce installed: pip install tensorforce\n",
    "from tensorforce.agents import PPOAgent\n",
    "from tensorforce.execution import Runner\n",
    "from tensorforce.contrib.openai_gym import OpenAIGym"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_np_float(feature):\n",
    "    return np.array(feature).astype(np.float32)\n",
    "\n",
    "def featurize(obs):\n",
    "    \"\"\"Flatten one agent's observation dict into a single float32 array.\n",
    "\n",
    "    The pieces are concatenated in this order: board, bomb_blast_strength\n",
    "    and bomb_life (each flattened with reshape(-1)), then position, ammo,\n",
    "    blast_strength, can_kick, teammate and enemies. A missing teammate is\n",
    "    encoded as -1 and the enemies list is right-padded with -1 up to three\n",
    "    entries.\n",
    "    \"\"\"\n",
    "    pieces = []\n",
    "\n",
    "    # Grid-like entries: flatten and cast.\n",
    "    for grid_key in (\"board\", \"bomb_blast_strength\", \"bomb_life\"):\n",
    "        pieces.append(obs[grid_key].reshape(-1).astype(np.float32))\n",
    "\n",
    "    # Position is passed through as-is (assumed to be a short sequence).\n",
    "    pieces.append(make_np_float(obs[\"position\"]))\n",
    "\n",
    "    # Single-valued entries are wrapped in a one-element list first.\n",
    "    for scalar_key in (\"ammo\", \"blast_strength\", \"can_kick\"):\n",
    "        pieces.append(make_np_float([obs[scalar_key]]))\n",
    "\n",
    "    # Teammate: use its `.value`, or -1 when there is none.\n",
    "    ally = obs[\"teammate\"]\n",
    "    ally_value = -1 if ally is None else ally.value\n",
    "    pieces.append(make_np_float([ally_value]))\n",
    "\n",
    "    # Enemies: use their `.value`s, padded with -1 to length three.\n",
    "    foe_values = [foe.value for foe in obs[\"enemies\"]]\n",
    "    foe_values.extend([-1] * (3 - len(foe_values)))\n",
    "    pieces.append(make_np_float(foe_values))\n",
    "\n",
    "    return np.concatenate(pieces)\n",
    "\n",
    "class TensorforceAgent(BaseAgent):\n",
    "    \"\"\"Placeholder agent that fills the training slot in the environment.\n",
    "\n",
    "    It never chooses an action itself; in this notebook the action for the\n",
    "    training slot is inserted by `WrappedEnv.execute`.\n",
    "    \"\"\"\n",
    "\n",
    "    def act(self, obs, action_space):\n",
    "        \"\"\"Ignore the observation and return None.\"\"\"\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n"
     ]
    }
   ],
   "source": [
    "# Instantiate and seed the environment\n",
    "config = ffa_v0_fast_env()\n",
    "env = Pomme(**config[\"env_kwargs\"])\n",
    "env.seed(0)\n",
    "\n",
    "# Describe the Proximal Policy Optimization agent piece by piece:\n",
    "# float observations, integer actions, two dense layers of 64 units\n",
    "# and an Adam step optimizer.\n",
    "states_spec = dict(type='float', shape=env.observation_space.shape)\n",
    "actions_spec = dict(type='int', num_actions=env.action_space.n)\n",
    "network_spec = [\n",
    "    dict(type='dense', size=64),\n",
    "    dict(type='dense', size=64),\n",
    "]\n",
    "optimizer_spec = dict(type='adam', learning_rate=1e-4)\n",
    "\n",
    "agent = PPOAgent(\n",
    "    states=states_spec,\n",
    "    actions=actions_spec,\n",
    "    network=network_spec,\n",
    "    batching_capacity=1000,\n",
    "    step_optimizer=optimizer_spec,\n",
    ")\n",
    "\n",
    "# Add 3 SimpleAgent opponents (ids 0, 1 and 2)\n",
    "agents = []\n",
    "for agent_id in range(3):\n",
    "    opponent_character = config[\"agent\"](agent_id, config[\"game_type\"])\n",
    "    agents.append(SimpleAgent(opponent_character))\n",
    "\n",
    "# Add the TensorforceAgent in the next slot and register it as the\n",
    "# training agent.\n",
    "agent_id += 1\n",
    "trainee = TensorforceAgent(config[\"agent\"](agent_id, config[\"game_type\"]))\n",
    "agents.append(trainee)\n",
    "env.set_agents(agents)\n",
    "env.set_training_agent(trainee.agent_id)\n",
    "env.set_init_game_state(None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class WrappedEnv(OpenAIGym):\n",
    "    \"\"\"Adapter that exposes an already-built Pommerman env to Tensorforce.\n",
    "\n",
    "    Only the training agent's featurized observation and reward are passed\n",
    "    on; the other agents' actions come from `gym.act`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, gym, visualize=False):\n",
    "        \"\"\"Store the environment and the render flag.\n",
    "\n",
    "        The parent initializer is not called; this wrapper just keeps the\n",
    "        environment object it is given.\n",
    "        \"\"\"\n",
    "        self.gym = gym\n",
    "        self.visualize = visualize\n",
    "\n",
    "    def execute(self, action):\n",
    "        \"\"\"Advance the game one step using `action` for the training agent.\n",
    "\n",
    "        Returns a `(state, terminal, reward)` tuple where state and reward\n",
    "        belong to the training agent only.\n",
    "        \"\"\"\n",
    "        if self.visualize:\n",
    "            self.gym.render()\n",
    "\n",
    "        actions = self.unflatten_action(action=action)\n",
    "\n",
    "        obs = self.gym.get_observations()\n",
    "        # Presumably `act` returns actions for the non-training agents only,\n",
    "        # so the learner's action is spliced in at its own index.\n",
    "        all_actions = self.gym.act(obs)\n",
    "        all_actions.insert(self.gym.training_agent, actions)\n",
    "        state, reward, terminal, _ = self.gym.step(all_actions)\n",
    "        agent_state = featurize(state[self.gym.training_agent])\n",
    "        agent_reward = reward[self.gym.training_agent]\n",
    "        return agent_state, terminal, agent_reward\n",
    "\n",
    "    def reset(self):\n",
    "        \"\"\"Reset the game and return the training agent's featurized observation.\"\"\"\n",
    "        obs = self.gym.reset()\n",
    "        # Index by the configured training agent (as `execute` does) instead\n",
    "        # of a hard-coded slot, so the wrapper works for any training index.\n",
    "        agent_obs = featurize(obs[self.gym.training_agent])\n",
    "        return agent_obs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stats:  [-1, -1, -1, -1, -1] [15, 15, 27, 32, 26] [2.0443358421325684, 0.7581827640533447, 1.3421897888183594, 1.6136739253997803, 1.2573180198669434]\n"
     ]
    }
   ],
   "source": [
    "# Instantiate and run the environment for 5 episodes.\n",
    "num_episodes = 5\n",
    "max_timesteps = 2000\n",
    "\n",
    "wrapped_env = WrappedEnv(env, visualize=True)\n",
    "runner = Runner(agent=agent, environment=wrapped_env)\n",
    "runner.run(episodes=num_episodes, max_episode_timesteps=max_timesteps)\n",
    "print(\"Stats: \", runner.episode_rewards, runner.episode_timesteps, runner.episode_times)\n",
    "\n",
    "# Closing the runner is best effort: an AttributeError raised by close()\n",
    "# is ignored.\n",
    "try:\n",
    "    runner.close()\n",
    "except AttributeError:\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pommerman",
   "language": "python",
   "name": "pommerman"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
